package org.etnaframework.module.base.utils;

import org.etnaframework.module.base.config.ConstantKeys;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * CRON表达式执行辅助工具
 *
 * @author jasyaf
 * @since 2023-11-17
 */
public class CronExecutorUtils {

    private static final Logger log = LoggerFactory.getLogger(CronExecutorUtils.class);

    private static final ScheduledExecutorService DEFAULT_DISPATCHER = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors(), new NamedThreadFactory("DispatchUCron::", Thread.MAX_PRIORITY, true));

    private static final ExecutorService DEFAULT_EXECUTOR = Executors.newCachedThreadPool(new NamedThreadFactory("UCron::", Thread.NORM_PRIORITY, true));

    private static final ThreadLocal<Long> thisTimestamp = new ThreadLocal<>();

    /**
     * Submits a task described by a CRON expression, running it on the built-in thread pools. Nothing runs at
     * submission time; the first run happens at the next time matching the expression.
     *
     * @param cron     expression with 6 fields: second minute hour day-of-month month day-of-week, see {@link CronGenerator}
     * @param runnable the task to run at each trigger time
     * @return a handle that can be used to stop future runs
     * @throws IllegalArgumentException if {@code cron} cannot be parsed
     */
    public static CronTaskTrigger schedule(String cron, Runnable runnable) {
        return scheduleWithDelay(DEFAULT_DISPATCHER, DEFAULT_EXECUTOR, cron, 0, runnable);
    }

    /**
     * Submits a task described by a CRON expression. Nothing runs at submission time; the first run happens at the
     * next time matching the expression.
     *
     * @param dispatcher pool that computes trigger times and dispatches the runs
     * @param executor   pool that actually runs the task
     * @param cron       expression with 6 fields: second minute hour day-of-month month day-of-week, see {@link CronGenerator}
     * @param runnable   the task to run at each trigger time
     * @return a handle that can be used to stop future runs
     * @throws IllegalArgumentException if {@code cron} cannot be parsed
     */
    public static CronTaskTrigger schedule(ScheduledExecutorService dispatcher, ExecutorService executor, String cron, Runnable runnable) {
        return scheduleWithDelay(dispatcher, executor, cron, 0, runnable);
    }

    /**
     * Submits a task described by a CRON expression, offsetting every trigger time by {@code delayMs}. Nothing runs at
     * submission time; the first run happens at the next time matching the expression.
     *
     * <p>This is the single validation and submission path that every other {@code schedule*} overload delegates to.
     *
     * @param dispatcher pool that computes trigger times and dispatches the runs
     * @param executor   pool that actually runs the task
     * @param cron       expression with 6 fields: second minute hour day-of-month month day-of-week, see {@link CronGenerator}
     * @param delayMs    offset in milliseconds added to every trigger time, in {@code [0, 1000)}; giving several tasks
     *                   different offsets staggers them so they do not all hit downstream resources on the same whole second
     * @param runnable   the task to run at each trigger time
     * @return a handle that can be used to stop future runs
     * @throws IllegalArgumentException if {@code delayMs} is out of range or {@code cron} cannot be parsed
     */
    public static CronTaskTrigger scheduleWithDelay(ScheduledExecutorService dispatcher, ExecutorService executor, String cron, long delayMs, Runnable runnable) {
        if (delayMs < 0 || delayMs >= 1000) {
            throw new IllegalArgumentException("delayMs ranges in [0,1000) and value " + delayMs + " is invalid");
        }
        CronTaskTrigger task = new CronTaskTrigger(dispatcher, executor, cron, runnable, delayMs);
        // This first dispatch only computes the first trigger time; it does not run the task
        dispatcher.execute(new CronTaskTriggerRun(task));
        return task;
    }

    /**
     * Submits a task described by a CRON expression on the built-in thread pools, offsetting every trigger time by
     * {@code delayMs}. Nothing runs at submission time; the first run happens at the next time matching the expression.
     *
     * @param cron     expression with 6 fields: second minute hour day-of-month month day-of-week, see {@link CronGenerator}
     * @param delayMs  offset in milliseconds added to every trigger time, in {@code [0, 1000)}; giving several tasks
     *                 different offsets staggers them so they do not all hit downstream resources on the same whole second
     * @param runnable the task to run at each trigger time
     * @return a handle that can be used to stop future runs
     * @throws IllegalArgumentException if {@code delayMs} is out of range or {@code cron} cannot be parsed
     */
    public static CronTaskTrigger scheduleWithDelay(String cron, long delayMs, Runnable runnable) {
        return scheduleWithDelay(DEFAULT_DISPATCHER, DEFAULT_EXECUTOR, cron, delayMs, runnable);
    }

    /**
     * Returns the planned start time (epoch milliseconds, including any configured delayMs) of the cron run executing
     * on the current thread. It identifies which occurrence of the schedule is being executed.
     *
     * @return the planned start time, or {@code 0} when the current thread is not running a cron task
     */
    public static long getThisTimestamp() {
        Long planned = thisTimestamp.get();
        return planned != null ? planned : 0L;
    }

    /**
     * Handle for a scheduled CRON task, returned by the {@code schedule*} methods. It holds the pools, the parsed
     * expression and the scheduling state that the dispatcher reads and updates.
     */
    public static class CronTaskTrigger {

        /** Pool that computes trigger times and schedules the next run. */
        final ScheduledExecutorService dispatcher;

        /** Pool that actually runs {@link #runnable}. */
        final ExecutorService executor;

        final CronGenerator cronGenerator;

        /** Suffix of the per-run trace id: hex form of the expression's hash code. */
        final String traceIdSuffix;

        final Runnable runnable;

        /** Offset in milliseconds added to every computed trigger time. */
        final long delayMs;

        /** Planned start (epoch millis) of the next run; 0 until the first trigger time has been computed. */
        volatile long nextStartTime;

        /** Set by {@link #stop()}; checked by the dispatcher before each run. */
        volatile boolean stop;

        /**
         * @throws IllegalArgumentException if {@code cron} cannot be parsed
         * @throws NullPointerException     if {@code dispatcher}, {@code executor} or {@code runnable} is null
         */
        CronTaskTrigger(ScheduledExecutorService dispatcher, ExecutorService executor, String cron, Runnable runnable, long delayMs) {
            // Parse first so an invalid expression is still reported as IllegalArgumentException
            this.cronGenerator = new CronGenerator(cron);
            // Fail fast: a null executor would otherwise only fail on the dispatcher thread, where the exception is
            // captured by the dispatcher's Future and lost, and a null runnable would be logged as an NPE on every run
            this.dispatcher = Objects.requireNonNull(dispatcher, "dispatcher");
            this.executor = Objects.requireNonNull(executor, "executor");
            this.runnable = Objects.requireNonNull(runnable, "runnable");
            this.traceIdSuffix = Integer.toHexString(cron.hashCode());
            this.delayMs = delayMs;
            this.nextStartTime = 0;
            this.stop = false;
        }

        /**
         * @return the CRON expression this task was scheduled with
         */
        public String getCron() {
            return this.cronGenerator.getExpression();
        }

        /**
         * @return the offset in milliseconds added to every trigger time
         */
        public long getDelayMs() {
            return delayMs;
        }

        /**
         * Stops future runs. A run that is already in progress is not interrupted.
         */
        public void stop() {
            this.stop = true;
        }
    }

    private static class CronTaskTriggerRun implements Runnable {

        final CronTaskTrigger trigger;

        CronTaskTriggerRun(CronTaskTrigger trigger) {
            this.trigger = trigger;
        }

        @Override
        public void run() {

            if (trigger.stop) {
                return;
            }

            long thisTimeTag = trigger.nextStartTime;

            // 系统时间校正时，如果本地时间比标准时间快，校正时时间往后退，如果退到了lastStartTime之前，计算下次任务时又会算到这个时间点，导致已执行的任务再次被触发执行
            // 这里把上次计算的nextStartTime（实质是预计本次执行的时间）和当前时间进行比对，取较靠后者，确保开始执行时间不能回退，避免上述bug
            long now = Math.max(System.currentTimeMillis(), trigger.nextStartTime);
            trigger.nextStartTime = trigger.cronGenerator.next(now) + trigger.delayMs;

            // dispatcher线程池负责任务调度，先将下次执行的任务提交，确保下次任务能按时触发，然后在executor线程开始本次任务
            trigger.dispatcher.schedule(this, Math.max(trigger.nextStartTime - now, 0), TimeUnit.MILLISECONDS);

            if (thisTimeTag <= 0) {
                return;
            }
            trigger.executor.execute(() -> {
                try {
                    thisTimestamp.set(thisTimeTag);
                    MDC.put(ConstantKeys.TAG_TRACE_ID, DatetimeUtils.format(thisTimeTag, "yyyyMMdd_HHmmss") + "_" + trigger.traceIdSuffix); // 请求开始时分配traceId（用于追踪请求链路）
                    trigger.runnable.run();
                } catch (Throwable ex) {
                    log.error("", ex);
                } finally {
                    thisTimestamp.remove();
                    MDC.clear();
                }
            });
        }
    }

    /**
     * CRON expression parser, derived from Spring's {@code CronSequenceGenerator}.
     *
     * <p>Expressions have 6 space-separated fields: second minute hour day-of-month month day-of-week.
     * Examples:
     * <pre>
     * "0 0 * * * *"          = the top of every hour of every day
     * "*&#47;10 * * * * *"   = every ten seconds
     * "0 0 8-10 * * *"       = 8, 9 and 10 o'clock of every day
     * "0 0 6,19 * * *"       = 6:00 AM and 7:00 PM every day
     * "0 0/30 8-10 * * *"    = 8:00, 8:30, 9:00, 9:30, 10:00 and 10:30 every day
     * "0 0 9-17 * * MON-FRI" = on the hour nine-to-five weekdays
     * "0 0 0 25 12 ?"        = every Christmas Day at midnight
     * </pre>
     *
     * <p>The parsed bit sets are only written during construction; {@link #next(long)} reads them and works on a
     * fresh {@link Calendar} per call.
     */
    private static class CronGenerator {

        /** The original expression text, used by {@link #getExpression()} and in error messages. */
        private final String expression;

        /** Zone used for calendar arithmetic in {@link #next(long)}; {@code null} only for instances built by {@link #isValidExpression}. */
        @Nullable
        private final TimeZone timeZone;

        // Allowed values per field: bit i set means value i matches. Months are 0-based (Calendar convention),
        // days of month are 1-based and days of week use 0 = Sunday.
        private final BitSet months = new BitSet(12);

        private final BitSet daysOfMonth = new BitSet(31);

        private final BitSet daysOfWeek = new BitSet(7);

        private final BitSet hours = new BitSet(24);

        private final BitSet minutes = new BitSet(60);

        private final BitSet seconds = new BitSet(60);

        /**
         * Constructs a generator from the given pattern, using the default {@link TimeZone}.
         *
         * @param expression a space-separated list of time fields
         * @throws IllegalArgumentException if the pattern cannot be parsed
         * @see TimeZone#getDefault()
         */
        public CronGenerator(String expression) {
            this(expression, TimeZone.getDefault());
        }

        /**
         * Constructs a generator from the given pattern, using the specified {@link TimeZone}.
         *
         * @param expression a space-separated list of time fields
         * @param timeZone   the TimeZone to use for generated trigger times
         * @throws IllegalArgumentException if the pattern cannot be parsed
         */
        public CronGenerator(String expression, TimeZone timeZone) {
            this.expression = expression;
            this.timeZone = timeZone;
            parse(expression);
        }

        /**
         * Validation-only constructor for already tokenized fields; the resulting instance has no time zone and is not
         * meant for {@link #next(long)}.
         */
        private CronGenerator(String expression, String[] fields) {
            this.expression = expression;
            this.timeZone = null;
            doParse(fields);
        }

        /**
         * Returns the cron pattern that this generator has been built for.
         */
        String getExpression() {
            return this.expression;
        }

        /**
         * Returns the next time in the sequence matching the cron pattern after the value provided. The result is a
         * whole number of seconds and is strictly after the input value.
         *
         * @param date a seed value, in epoch milliseconds
         * @return the next matching time, in epoch milliseconds
         * @throws IllegalArgumentException if no matching time can be found within the search limits
         */
        public long next(long date) {
            /*
            The plan:

            1 Start with whole second (rounding up if necessary)

            2 If seconds match move on, otherwise find the next match:
            2.1 If next match is in the next minute then roll forwards

            3 If minute matches move on, otherwise find the next match
            3.1 If next match is in the next hour then roll forwards
            3.2 Reset the seconds and go to 2

            4 If hour matches move on, otherwise find the next match
            4.1 If next match is in the next day then roll forwards,
            4.2 Reset the minutes and seconds and go to 2
            */

            Calendar calendar = new GregorianCalendar();
            calendar.setTimeZone(this.timeZone);
            calendar.setTime(new Date(date));

            // First, just reset the milliseconds and try to calculate from there...
            calendar.set(Calendar.MILLISECOND, 0);
            long originalTimestamp = calendar.getTimeInMillis();
            doNext(calendar, calendar.get(Calendar.YEAR));

            if (calendar.getTimeInMillis() == originalTimestamp) {
                // We arrived at the original timestamp - round up to the next whole second and try again...
                calendar.add(Calendar.SECOND, 1);
                doNext(calendar, calendar.get(Calendar.YEAR));
            }

            return calendar.getTime().getTime();
        }

        /**
         * Advances {@code calendar} to the first matching time at or after its current value, field by field from
         * seconds up to months, restarting whenever a higher-order field had to move.
         *
         * @param dot the year the search started in; used to abort searches that run away more than 4 years
         */
        private void doNext(Calendar calendar, int dot) {
            List<Integer> resets = new ArrayList<>();

            int second = calendar.get(Calendar.SECOND);
            List<Integer> emptyList = Collections.emptyList();
            int updateSecond = findNext(this.seconds, second, calendar, Calendar.SECOND, Calendar.MINUTE, emptyList);
            if (second == updateSecond) {
                resets.add(Calendar.SECOND);
            }

            int minute = calendar.get(Calendar.MINUTE);
            int updateMinute = findNext(this.minutes, minute, calendar, Calendar.MINUTE, Calendar.HOUR_OF_DAY, resets);
            if (minute == updateMinute) {
                resets.add(Calendar.MINUTE);
            } else {
                doNext(calendar, dot);
            }

            int hour = calendar.get(Calendar.HOUR_OF_DAY);
            int updateHour = findNext(this.hours, hour, calendar, Calendar.HOUR_OF_DAY, Calendar.DAY_OF_WEEK, resets);
            if (hour == updateHour) {
                resets.add(Calendar.HOUR_OF_DAY);
            } else {
                doNext(calendar, dot);
            }

            int dayOfWeek = calendar.get(Calendar.DAY_OF_WEEK);
            int dayOfMonth = calendar.get(Calendar.DAY_OF_MONTH);
            int updateDayOfMonth = findNextDay(calendar, this.daysOfMonth, dayOfMonth, this.daysOfWeek, dayOfWeek, resets);
            if (dayOfMonth == updateDayOfMonth) {
                resets.add(Calendar.DAY_OF_MONTH);
            } else {
                doNext(calendar, dot);
            }

            int month = calendar.get(Calendar.MONTH);
            int updateMonth = findNext(this.months, month, calendar, Calendar.MONTH, Calendar.YEAR, resets);
            if (month != updateMonth) {
                if (calendar.get(Calendar.YEAR) - dot > 4) {
                    throw new IllegalArgumentException("Invalid cron expression \"" + this.expression + "\" led to runaway search for next trigger");
                }
                doNext(calendar, dot);
            }
        }

        /**
         * Advances {@code calendar} one day at a time until both the day of month and the day of week match, giving up
         * after 366 days.
         *
         * @return the matching day of month
         * @throws IllegalArgumentException if no matching day is found within 366 days
         */
        private int findNextDay(Calendar calendar, BitSet daysOfMonth, int dayOfMonth, BitSet daysOfWeek, int dayOfWeek, List<Integer> resets) {

            int count = 0;
            int max = 366;
            // the DAY_OF_WEEK values in java.util.Calendar start with 1 (Sunday),
            // but in the cron pattern, they start with 0, so we subtract 1 here
            while ((!daysOfMonth.get(dayOfMonth) || !daysOfWeek.get(dayOfWeek - 1)) && count++ < max) {
                calendar.add(Calendar.DAY_OF_MONTH, 1);
                dayOfMonth = calendar.get(Calendar.DAY_OF_MONTH);
                dayOfWeek = calendar.get(Calendar.DAY_OF_WEEK);
                reset(calendar, resets);
            }
            if (count >= max) {
                throw new IllegalArgumentException("Overflow in day for expression \"" + this.expression + "\"");
            }
            return dayOfMonth;
        }

        /**
         * Searches the bits provided for the next set bit at or after the value provided, and resets the calendar.
         *
         * @param bits        a {@link BitSet} representing the allowed values of the field
         * @param value       the current value of the field
         * @param calendar    the calendar to increment as we move through the bits
         * @param field       the field to increment in the calendar (see {@link Calendar} for the static constants
         *                    defining valid fields)
         * @param nextField   the next higher-order field, incremented by one when the search wraps around
         * @param lowerOrders the Calendar field ids that should be reset (i.e. the ones of lower significance than the
         *                    field of interest)
         * @return the value of the calendar field that is next in the sequence
         */
        private int findNext(BitSet bits, int value, Calendar calendar, int field, int nextField, List<Integer> lowerOrders) {
            int nextValue = bits.nextSetBit(value);
            // roll over if needed
            if (nextValue == -1) {
                calendar.add(nextField, 1);
                reset(calendar, Collections.singletonList(field));
                nextValue = bits.nextSetBit(0);
            }
            if (nextValue != value) {
                calendar.set(field, nextValue);
                reset(calendar, lowerOrders);
            }
            return nextValue;
        }

        /**
         * Resets the given calendar fields to their minimum: 1 for DAY_OF_MONTH, 0 for everything else.
         */
        private void reset(Calendar calendar, List<Integer> fields) {
            for (int field : fields) {
                calendar.set(field, field == Calendar.DAY_OF_MONTH ? 1 : 0);
            }
        }

        // Parsing logic invoked by the constructor

        /**
         * Parses the given pattern expression into the per-field bit sets.
         *
         * @throws IllegalArgumentException if the expression does not have exactly 6 fields or a field is invalid
         */
        private void parse(String expression) throws IllegalArgumentException {
            String[] fields = StringUtils.tokenizeToStringArray(expression, " ");
            if (!areValidCronFields(fields)) {
                throw new IllegalArgumentException(String.format("Cron expression must consist of 6 fields (found %d in \"%s\")", fields.length, expression));
            }
            doParse(fields);
        }

        private void doParse(String[] fields) {
            setNumberHits(this.seconds, fields[0], 0, 60);
            setNumberHits(this.minutes, fields[1], 0, 60);
            setNumberHits(this.hours, fields[2], 0, 24);
            setDaysOfMonth(this.daysOfMonth, fields[3]);
            setMonths(this.months, fields[4]);
            setDays(this.daysOfWeek, replaceOrdinals(fields[5], "SUN,MON,TUE,WED,THU,FRI,SAT"), 8);

            if (this.daysOfWeek.get(7)) {
                // Sunday can be represented as 0 or 7
                this.daysOfWeek.set(0);
                this.daysOfWeek.clear(7);
            }
        }

        /**
         * Replaces the names in the comma-separated list (case-insensitively) with their index in the list.
         *
         * @return a new upper-cased String with the names from the list replaced by their indexes
         */
        private String replaceOrdinals(String value, String commaSeparatedList) {
            String[] list = StringUtils.commaDelimitedListToStringArray(commaSeparatedList);
            // Locale.ROOT: with a locale-sensitive toUpperCase(), e.g. under a Turkish default locale, "fri" becomes
            // "FRİ" (dotted capital I) and would no longer match "FRI"
            String result = value.toUpperCase(Locale.ROOT);
            for (int i = 0; i < list.length; i++) {
                String item = list[i].toUpperCase(Locale.ROOT);
                result = StringUtils.replace(result, item, "" + i);
            }
            return result;
        }

        private void setDaysOfMonth(BitSet bits, String field) {
            int max = 31;
            // Days of month start with 1 (in Cron and Calendar) so add one
            setDays(bits, field, max + 1);
            // ... and remove it from the front
            bits.clear(0);
        }

        /**
         * Sets the allowed day values; {@code "?"} is treated like {@code "*"}.
         */
        private void setDays(BitSet bits, String field, int max) {
            if (field.contains("?")) {
                field = "*";
            }
            setNumberHits(bits, field, 0, max);
        }

        private void setMonths(BitSet bits, String value) {
            int max = 12;
            // "FOO" is a placeholder at index 0 so that JAN maps to 1, matching numeric cron months
            value = replaceOrdinals(value, "FOO,JAN,FEB,MAR,APR,MAY,JUN,JUL,AUG,SEP,OCT,NOV,DEC");
            BitSet months = new BitSet(13);
            // Months start with 1 in Cron and 0 in Calendar, so push the values first into a longer bit set
            setNumberHits(months, value, 1, max + 1);
            // ... and then rotate it to the front of the months
            for (int i = 1; i <= max; i++) {
                if (months.get(i)) {
                    bits.set(i - 1);
                }
            }
        }

        /**
         * Sets the bits for a comma-separated list of single values, ranges ({@code a-b}), wildcards and
         * incrementers ({@code start/step} or {@code a-b/step}).
         *
         * @param min lowest allowed value (inclusive)
         * @param max upper bound (exclusive)
         */
        private void setNumberHits(BitSet bits, String value, int min, int max) {
            String[] fields = StringUtils.delimitedListToStringArray(value, ",");
            for (String field : fields) {
                if (!field.contains("/")) {
                    // Not an incrementer so it must be a range (possibly empty)
                    int[] range = getRange(field, min, max);
                    bits.set(range[0], range[1] + 1);
                } else {
                    String[] split = StringUtils.delimitedListToStringArray(field, "/");
                    if (split.length > 2) {
                        throw new IllegalArgumentException("Incrementer has more than two fields: '" + field + "' in expression \"" + this.expression + "\"");
                    }
                    int[] range = getRange(split[0], min, max);
                    if (!split[0].contains("-")) {
                        // A single start value means "from start to the end of the field"
                        range[1] = max - 1;
                    }
                    int delta = Integer.parseInt(split[1]);
                    if (delta <= 0) {
                        throw new IllegalArgumentException("Incrementer delta must be 1 or higher: '" + field + "' in expression \"" + this.expression + "\"");
                    }
                    for (int i = range[0]; i <= range[1]; i += delta) {
                        bits.set(i);
                    }
                }
            }
        }

        /**
         * Parses a single value, a {@code a-b} range or a wildcard into an inclusive {@code [from, to]} pair.
         *
         * @throws IllegalArgumentException if the range is malformed, inverted or outside {@code [min, max)}
         */
        private int[] getRange(String field, int min, int max) {
            int[] result = new int[2];
            if (field.contains("*")) {
                result[0] = min;
                result[1] = max - 1;
                return result;
            }
            if (!field.contains("-")) {
                result[0] = result[1] = Integer.parseInt(field);
            } else {
                String[] split = StringUtils.delimitedListToStringArray(field, "-");
                if (split.length > 2) {
                    throw new IllegalArgumentException("Range has more than two fields: '" + field + "' in expression \"" + this.expression + "\"");
                }
                result[0] = Integer.parseInt(split[0]);
                result[1] = Integer.parseInt(split[1]);
            }
            if (result[0] >= max || result[1] >= max) {
                throw new IllegalArgumentException("Range exceeds maximum (" + max + "): '" + field + "' in expression \"" + this.expression + "\"");
            }
            if (result[0] < min || result[1] < min) {
                throw new IllegalArgumentException("Range less than minimum (" + min + "): '" + field + "' in expression \"" + this.expression + "\"");
            }
            if (result[0] > result[1]) {
                throw new IllegalArgumentException("Invalid inverted range: '" + field + "' in expression \"" + this.expression + "\"");
            }
            return result;
        }

        /**
         * Determines whether the specified expression represents a valid cron pattern.
         *
         * @param expression the expression to evaluate
         * @return {@code true} if the given expression is a valid cron expression
         */
        public static boolean isValidExpression(@Nullable String expression) {
            if (expression == null) {
                return false;
            }
            String[] fields = StringUtils.tokenizeToStringArray(expression, " ");
            if (!areValidCronFields(fields)) {
                return false;
            }
            try {
                new CronGenerator(expression, fields);
                return true;
            } catch (IllegalArgumentException ex) {
                return false;
            }
        }

        private static boolean areValidCronFields(@Nullable String[] fields) {
            return (fields != null && fields.length == 6);
        }

        @Override
        public boolean equals(@Nullable Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof CronGenerator)) {
                return false;
            }
            CronGenerator otherCron = (CronGenerator) other;
            return (this.months.equals(otherCron.months) && this.daysOfMonth.equals(otherCron.daysOfMonth) && this.daysOfWeek.equals(otherCron.daysOfWeek) && this.hours.equals(otherCron.hours) && this.minutes.equals(otherCron.minutes) && this.seconds.equals(otherCron.seconds));
        }

        @Override
        public int hashCode() {
            return (17 * this.months.hashCode() + 29 * this.daysOfMonth.hashCode() + 37 * this.daysOfWeek.hashCode() + 41 * this.hours.hashCode() + 53 * this.minutes.hashCode() + 61 * this.seconds.hashCode());
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + ": " + this.expression;
        }
    }
}
